import os
import numpy as np
import matplotlib.pyplot as plt

import sys
sys.path.insert(0, '/media/tan/_dde_data/projects/ginkgo_vfife/scripts')
from visualization import ResultFileSystem, TimeHistory


def trajectory(t, m, c, g=-9.8, v0=10):
    """Analytical trajectory of a horizontally launched particle with linear damping.

    The particle starts at the origin with horizontal velocity ``v0`` and zero
    vertical velocity, and obeys ``m * a = m * (0, g) - c * v``. When
    ``|c| <= 1e-6`` the undamped (ballistic) closed form is used instead, to
    avoid dividing by a (near) zero damping coefficient.

    Args:
        t (array_like): time series, s. Lists and scalars are accepted and
            converted with ``np.asarray``.
        m (float): mass, kg
        c (float): linear damping coefficient, kg/s
        g (float, optional): signed acceleration of gravity (negative points
            down). Defaults to -9.8.
        v0 (float, optional): initial horizontal velocity. Defaults to 10.

    Returns:
        tuple of ndarray: (x, vx, ax, y, vy, ay), each shaped like ``t``.
    """
    # Coerce to a float array: with a plain list, ``v0 * t`` would repeat the
    # list instead of scaling it, and ``-c * t`` would raise TypeError.
    t = np.asarray(t, dtype=float)

    if abs(c) > 1e-6:
        # Shared exponential decay factor e^{-c t / m}.
        decay = np.exp(-c * t / m)

        x = m * v0 / c * (1 - decay)
        vx = v0 * decay
        ax = -c / m * v0 * decay

        y = m * g / c * t + m**2 * g / (c**2) * (decay - 1)
        vy = m * g / c * (1 - decay)
        ay = g * decay
    else:
        # Undamped limit: uniform motion in x, uniform acceleration in y.
        x = v0 * t
        vx = v0 * np.ones_like(t)
        ax = np.zeros_like(t)

        y = 0.5 * g * t**2
        vy = g * t
        ay = g * np.ones_like(t)

    return (x, vx, ax, y, vy, ay)


def compare(data, keys, item, figname):
    """Plot theory vs. openVFIFE curves for one motion component and save them.

    Args:
        data (dict): ``{"theory": {key: arr}, "vfife": {key: arr}}``, where
            each ``arr`` is a 2-D array whose column 0 is time and whose
            columns 1, 2 (, 3) hold the x, y (, z) components.
        keys (list): curve keys to draw, e.g. ["0.0", "0.2", ...]; each key
            must exist in both ``data["theory"]`` and ``data["vfife"]``.
        item (str): component to plot: "x", "y" or "z".
        figname (str): output path passed to ``Figure.savefig``.

    Raises:
        ValueError: if ``item`` is not a known component or ``keys`` is empty.
    """
    columns = {"x": 1, "y": 2, "z": 3}
    if item not in columns:
        # Previously any non-"x" value silently plotted the y column.
        raise ValueError("item must be one of %s, got %r"
                         % (sorted(columns), item))
    if not keys:
        # Without at least one curve there are no handles for the legend.
        raise ValueError("keys must not be empty")
    ind = columns[item]

    theory = data["theory"]
    vfife = data["vfife"]
    figsize = np.array([8, 6]) / 2.54  # 8 cm x 6 cm, converted to inches
    fig, ax = plt.subplots(figsize=figsize, tight_layout=True)
    try:
        for cnt, key in enumerate(keys):
            th = theory[key]
            vf = vfife[key]
            l1, = ax.plot(th[:, 0], th[:, ind], c="black", lw=1.5)
            l2, = ax.plot(vf[:, 0], vf[:, ind], c="red", lw=1.5,
                          ls="dashed")

            # Stagger label positions (10 more samples back per curve) so the
            # labels of neighbouring curves don't overlap; clamp to the first
            # sample for short series instead of raising IndexError.
            row = max(-10 - 10 * cnt, -len(vf))
            ax.text(vf[row, 0], vf[row, ind], key, fontsize=8,
                    ha="center", va="center", c="blue")

        ax.legend([l1, l2], ["theory", "openVFIFE"], loc="best")
        fig.savefig(figname, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


DEFAULT_RESULT_DIR = "/home/tan/Desktop/test/vfife_examples/example1/"


def main(result_dir=DEFAULT_RESULT_DIR):
    """Compare openVFIFE projectile results with the analytical solution.

    Loads particle 1's displacement, velocity and acceleration histories for
    six damping cases (c/m = 0.0, 0.2, ..., 1.0), computes the corresponding
    analytical curves with ``trajectory`` and writes x/y comparison plots
    (dx, dy, vx, vy, ax, ay ``.svg``) into the current working directory.

    Args:
        result_dir (str): directory containing one openVFIFE result folder
            per damping case, named ``"c" + "%.1f" % (c / mass)``.
    """
    mass = 10
    clist = [0.2 * i * mass for i in range(6)]
    dt = 0.1
    t_end = 10.0
    # linspace instead of arange(0, t_end + dt, dt): a float step in arange
    # can give an off-by-one number of samples because of rounding.
    t = np.linspace(0.0, t_end, int(round(t_end / dt)) + 1)
    keys = ["%.1f" % (c / mass) for c in clist]

    # data containers: {"vfife": {key: arr}, "theory": {key: arr}}
    displace = {"vfife": {}, "theory": {}}
    velocity = {"vfife": {}, "theory": {}}
    accelerate = {"vfife": {}, "theory": {}}

    # load data from the openVFIFE results (particle 1 of every case)
    for key in keys:
        fname = os.path.join(result_dir, "c" + key)
        print(fname)
        history = TimeHistory(ResultFileSystem(fname))
        displace["vfife"][key] = history.extract_particle_motion(1, "displace")
        velocity["vfife"][key] = history.extract_particle_motion(1, "velocity")
        # Drop the first row (t = 0) of the acceleration history.
        # NOTE(review): presumably the solver's t = 0 acceleration is not
        # meaningful -- confirm.
        accelerate["vfife"][key] = history.extract_particle_motion(
            1, "accelerate")[1:, :]

    # theory results, stored as [t, x, y] columns
    for key, c in zip(keys, clist):
        x, vx, ax, y, vy, ay = trajectory(t, mass, c)
        displace["theory"][key] = np.column_stack((t, x, y))
        velocity["theory"][key] = np.column_stack((t, vx, vy))
        accelerate["theory"][key] = np.column_stack((t, ax, ay))

    print(displace)
    print(accelerate)

    cwd = os.getcwd()
    for prefix, data in (("d", displace), ("v", velocity), ("a", accelerate)):
        for item in ("x", "y"):
            figname = os.path.join(cwd, prefix + item + ".svg")
            compare(data, keys, item, figname)


if __name__ == "__main__":
    # An optional first command-line argument overrides the result directory.
    main(sys.argv[1] if len(sys.argv) > 1 else DEFAULT_RESULT_DIR)